{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jYljlovDUEqY"
      },
      "source": [
        "Copyright 2022 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rSDyLxJWGweX"
      },
      "source": [
        "## Synthetic experiment that sweeps over a range of distribution shifts\n",
        "\n",
        "This notebook implements an experiment that trains a model on a single source domain where $P(U=1)=0.1$ and evaluates on a collection of target domains where $P(U=1)=\\{0.1, \\ldots, 0.9\\}$. The notebook also produces a visualuzation of the results, analogous to Figure 3A in the paper. This notebook relies on previously running `colab/synthetic_data_to_file.ipynb`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IMDdsYh6NEyi"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sklearn\n",
        "import tensorflow as tf\n",
        "import ml_collections as mlc\n",
        "import scipy\n",
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "import re\n",
        "import os\n",
        "import itertools\n",
        "from sklearn.preprocessing import OneHotEncoder\n",
        "from sklearn.metrics import roc_auc_score\n",
        "from sklearn.metrics import accuracy_score\n",
        "from sklearn.metrics import log_loss\n",
        "from IPython.display import display\n",
        "from latent_shift_adaptation.methods.algorithms_sknp import get_classifier, latent_shift_adaptation\n",
        "from latent_shift_adaptation.utils import gumbelmax_vae_ci, gumbelmax_vae\n",
        "from latent_shift_adaptation.methods.vae import gumbelmax_vanilla, gumbelmax_graph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RkSTKQyBUEqe"
      },
      "outputs": [],
      "source": [
        "ITERATIONS = 10 # Set to 10 to replicate experiments in paper\n",
        "EPOCHS = 200 # Set to 200 to replicate experiments in paper"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i2Fd4jKSNfM8"
      },
      "outputs": [],
      "source": [
        "#@title Library functions\n",
        "\n",
        "def extract_from_df(samples_df, cols=['u', 'x', 'w', 'c', 'c_logits', 'y', 'y_logits', 'y_one_hot', \n",
        "                                      'u_one_hot', 'x_scaled',\n",
        "                                      'w_1', 'w_1_binary', 'w_1_one_hot',\n",
        "                                      'w_2', 'w_2_binary', 'w_2_one_hot',\n",
        "                                      'w_2_binary', 'w_2_one_hot',\n",
        "                                      ]):\n",
        "  \"\"\"\n",
        "  Extracts dict of numpy arrays from dataframe\n",
        "  \"\"\"\n",
        "  result = {}\n",
        "  for col in cols:\n",
        "    if col in samples_df.columns:\n",
        "      result[col] = samples_df[col].values\n",
        "    else:\n",
        "      match_str = f\"^{col}_\\\\d$\"\n",
        "      r = re.compile(match_str, re.IGNORECASE)\n",
        "      matching_columns = list(filter(r.match, samples_df.columns))\n",
        "      if len(matching_columns) == 0:\n",
        "        continue\n",
        "      result[col] = samples_df[matching_columns].to_numpy()\n",
        "  return result\n",
        "\n",
        "def extract_from_df_nested(samples_df, cols=['u', 'x', 'w', 'c', 'c_logits', 'y', 'y_logits', 'y_one_hot', 'w_binary', 'w_one_hot', 'u_one_hot', 'x_scaled']):\n",
        "  \"\"\"\n",
        "  Extracts nested dict of numpy arrays from dataframe with structure {domain: {partition: data}}\n",
        "  \"\"\"\n",
        "  result = {}\n",
        "  if 'domain' in samples_df.keys():\n",
        "    for domain in samples_df['domain'].unique():\n",
        "      result[domain] = {}\n",
        "      domain_df = samples_df.query('domain == @domain')\n",
        "      for partition in domain_df['partition'].unique():\n",
        "        partition_df = domain_df.query('partition == @partition')\n",
        "        result[domain][partition] = extract_from_df(partition_df, cols=cols)\n",
        "  else:\n",
        "    for partition in samples_df['partition'].unique():\n",
        "        partition_df = samples_df.query('partition == @partition')\n",
        "        result[partition] = extract_from_df(partition_df, cols=cols)\n",
        "  return result\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4kVfS6oTq4z0"
      },
      "outputs": [],
      "source": [
        "# Read the data\n",
        "folder_id='./tmp_data'\n",
        "p_u_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n",
        "w_coeff_list = [1, 2, 3]\n",
        "data_dict_all = {}\n",
        "for p_u_0, w_coeff in itertools.product(p_u_list, w_coeff_list):\n",
        "  print(p_u_0, w_coeff)\n",
        "  filename=f\"synthetic_multivariate_num_samples_10000_w_coeff_{w_coeff}_p_u_0_{p_u_0}.csv\"\n",
        "  data_df = pd.read_csv(os.path.join(folder_id, filename))\n",
        "  data_dict_all[(p_u_0, w_coeff)] = extract_from_df_nested(data_df)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fjyHcFNmn3Rh"
      },
      "outputs": [],
      "source": [
        "# Define Sklearn evaluation functions\n",
        "def soft_accuracy(y_true, y_pred, threshold=0.5, **kwargs):\n",
        "  return sklearn.metrics.accuracy_score(y_true, y_pred \u003e= threshold, **kwargs)\n",
        "\n",
        "def log_loss64(y_true, y_pred, **kwargs):\n",
        "  return sklearn.metrics.log_loss(y_true, y_pred.astype(np.float64), **kwargs)\n",
        "\n",
        "evals_sklearn = {\n",
        "    'cross-entropy': log_loss64,\n",
        "    'accuracy': soft_accuracy, \n",
        "    'auc': sklearn.metrics.roc_auc_score\n",
        "}\n",
        "def fit_and_evaluate_sk(data_dict_source: dict,\n",
        "                     data_dict_target: dict,\n",
        "                     model_type: str = 'mlp'\n",
        "                     ):\n",
        "  \n",
        "  ## Fit baselines\n",
        "  model = get_classifier(model_type)\n",
        "  model.fit(data_dict_source['train']['x'], data_dict_source['train']['y'])\n",
        "  model_target = get_classifier(model_type)\n",
        "  model_target.fit(data_dict_target['train']['x'], data_dict_target['train']['y'])\n",
        "  \n",
        "  # Apply LSA with known U\n",
        "  lsa_pred_probs_target = latent_shift_adaptation(\n",
        "      x_source=data_dict_source['train']['x'],\n",
        "      y_source=data_dict_source['train']['y'],\n",
        "      u_source=data_dict_source['train']['u'],\n",
        "      x_target=data_dict_target['test']['x'],\n",
        "      model_type=model_type)[:, -1]\n",
        "  \n",
        "  result = {}\n",
        "  for metric, eval_fn in evals_sklearn.items():\n",
        "    result[('erm-source', 'source', metric)] = eval_fn(data_dict_source['test']['y'], model.predict_proba(data_dict_source['test']['x'])[:, -1])\n",
        "    result[('erm-source', 'source', metric)] = eval_fn(data_dict_source['test']['y'], model.predict_proba(data_dict_source['test']['x'])[:, -1])\n",
        "    result[('erm-source', 'target', metric)] = eval_fn(data_dict_target['test']['y'], model.predict_proba(data_dict_target['test']['x'])[:, -1])\n",
        "    result[('erm-target', 'source', metric)] = eval_fn(data_dict_source['test']['y'], model_target.predict_proba(data_dict_source['test']['x'])[:, -1])\n",
        "    result[('erm-target', 'target', metric)] = eval_fn(data_dict_target['test']['y'], model_target.predict_proba(data_dict_target['test']['x'])[:, -1])\n",
        "    result[('lsa-oracle-sk', 'target', metric)] = eval_fn(data_dict_target['test']['y'], lsa_pred_probs_target)\n",
        "\n",
        "  return result"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tCjTjhrCPmFB"
      },
      "outputs": [],
      "source": [
        "data_dict_source = data_dict_all[(0.9, 1)]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7w5eKbcttixo"
      },
      "outputs": [],
      "source": [
        "result_list_sk = []\n",
        "for seed in range(ITERATIONS):\n",
        "  print(f'Iteration: {seed}')\n",
        "  np.random.seed(seed)\n",
        "  result_sk = {}\n",
        "  for p_u_0 in p_u_list:\n",
        "    print(p_u_0)\n",
        "    result_sk[p_u_0] = fit_and_evaluate_sk(data_dict_source, data_dict_all[(p_u_0, 1)], model_type='mlp')\n",
        "  result_list_sk.append(result_sk)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zTqganmOBtL5"
      },
      "outputs": [],
      "source": [
        "result_sk_concat = pd.concat([pd.concat({key: pd.Series(value) for key, value in elem.items()}).to_frame().rename_axis(['p_u_target_0', 'method', 'eval_set', 'metric']).rename(columns={0: 'performance'}).assign(iteration=i) for i, elem in enumerate(result_list_sk)]).reset_index()\n",
        "result_sk_concat"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vLiVsCl_vzSx"
      },
      "outputs": [],
      "source": [
        "result_sk_concat.query('method == \"erm-source\" \u0026 eval_set == \"target\" \u0026 metric == \"auc\" \u0026 p_u_target_0 == 0.1')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JEg0wqamxgSC"
      },
      "source": [
        "### Run VAE"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jqg2WTo9x3XE"
      },
      "outputs": [],
      "source": [
        "# Create TF datasets\n",
        "\n",
        "batch_size = 128\n",
        "\n",
        "ds_dict_source = {\n",
        "    w_coeff: {\n",
        "        key: tf.data.Dataset.from_tensor_slices(\n",
        "        (value['x'], value['y_one_hot'], value['c'], value['w_one_hot'], value['u_one_hot']), \n",
        "    ).repeat().shuffle(1000).batch(batch_size) for key, value in data_dict_all[(0.9, w_coeff)].items()\n",
        "  } for w_coeff in w_coeff_list\n",
        "}\n",
        "\n",
        "ds_dict_target = {\n",
        "    p_u_0: {\n",
        "        key: tf.data.Dataset.from_tensor_slices(\n",
        "        (value['x'], value['y_one_hot'], value['c'], value['w_one_hot'], value['u_one_hot']), \n",
        "    ).repeat().shuffle(1000).batch(batch_size) for key, value in data_dict_all[(p_u_0, 1)].items()\n",
        "    } for p_u_0 in p_u_list\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q8vENe2WzQkX"
      },
      "outputs": [],
      "source": [
        "ds_temp = ds_dict_source[1]\n",
        "batch = next(iter(ds_temp['train']))\n",
        "x_dim = batch[0].shape[1]\n",
        "c_dim = batch[2].shape[1]\n",
        "w_dim = batch[3].shape[1]\n",
        "u_dim = batch[4].shape[1]\n",
        "num_classes = 2\n",
        "test_fract = 0.2\n",
        "val_fract = 0.1\n",
        "\n",
        "num_examples = 10_000\n",
        "steps_per_epoch = num_examples // batch_size\n",
        "steps_per_epoch_test = int(steps_per_epoch * test_fract)\n",
        "steps_per_epoch_val = int(steps_per_epoch * val_fract)\n",
        "steps_per_epoch_train = steps_per_epoch - steps_per_epoch_test - steps_per_epoch_val\n",
        "\n",
        "pos = mlc.ConfigDict()\n",
        "pos.x, pos.y, pos.c, pos.w, pos.u = 0, 1, 2, 3, 4"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y4OgX_m80MO2"
      },
      "outputs": [],
      "source": [
        "def evaluate_clf(data_dict_source, data_dict_target):\n",
        "  result_dict = {}\n",
        "  for metric in evals_sklearn.keys():\n",
        "    result_dict[metric] = {}\n",
        "  \n",
        "  y_pred_source = clf.predict(data_dict_source['test']['x'])\n",
        "  y_pred_target = clf.predict(data_dict_target['test']['x'])\n",
        "  if 'cbm' in method:\n",
        "    # hacky workaround for now\n",
        "    y_pred_source = y_pred_source[1]\n",
        "    y_pred_target = y_pred_target[1]\n",
        "  y_pred_source = y_pred_source.numpy()[:, 1] if tf.is_tensor(y_pred_source) else y_pred_source[:, 1]\n",
        "  y_pred_target = y_pred_target.numpy()[:, 1] if tf.is_tensor(y_pred_target) else y_pred_target[:, 1]\n",
        "  y_true_source = data_dict_source['test']['y']\n",
        "  y_true_target = data_dict_target['test']['y']\n",
        "  \n",
        "  for metric in evals_sklearn.keys():\n",
        "    result_dict[metric]['source'] = evals_sklearn[metric](y_true_source, y_pred_source)\n",
        "    result_dict[metric]['target'] = evals_sklearn[metric](y_true_target, y_pred_target)\n",
        "  return result_dict"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xaICaXV5wyGL"
      },
      "outputs": [],
      "source": [
        "DEFAULT_LOSS = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n",
        "\n",
        "def mlp(num_classes, width, input_shape, learning_rate,\n",
        "        loss=DEFAULT_LOSS, metrics=[]):\n",
        "  \"\"\"Multilabel Classification.\"\"\"\n",
        "  model_input = tf.keras.Input(shape=input_shape)\n",
        "  # hidden layer\n",
        "  if width:\n",
        "    x = tf.keras.layers.Dense(\n",
        "        width, use_bias=True, activation='relu'\n",
        "    )(model_input)\n",
        "  else:\n",
        "    x = model_input\n",
        "  model_outuput = tf.keras.layers.Dense(num_classes,\n",
        "                                        use_bias=True,\n",
        "                                        activation=\"linear\")(x)  # get logits\n",
        "  opt = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)\n",
        "  model = tf.keras.models.Model(model_input, model_outuput)\n",
        "  model.build(input_shape)\n",
        "  model.compile(loss=loss, optimizer=opt, metrics=metrics)\n",
        "\n",
        "  return model\n",
        "\n",
        "xlabel = 'x'  # or 'x', 'x_scaled'\n",
        "SEED = 0\n",
        "tf.random.set_seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "evals = {  # evaluation functions\n",
        "    \"cross-entropy\": tf.keras.metrics.CategoricalCrossentropy(),\n",
        "    \"accuracy\": tf.keras.metrics.CategoricalAccuracy(),\n",
        "    \"auc\": tf.keras.metrics.AUC(multi_label = False)\n",
        "}\n",
        "reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(\n",
        "    monitor='loss', min_delta=0.01, factor=0.1, patience=20,\n",
        "    min_lr=1e-7)\n",
        "\n",
        "callbacks = [reduce_lr]\n",
        "\n",
        "do_calib = True\n",
        "evaluate = tf.keras.metrics.CategoricalCrossentropy()\n",
        "\n",
        "learning_rate = 0.01  #@param {type:\"number\"}\n",
        "width = 100  #@param {type:\"number\"}\n",
        "\n",
        "epochs = EPOCHS\n",
        "train_kwargs = {\n",
        "    'epochs': epochs,\n",
        "    'steps_per_epoch':steps_per_epoch_train,\n",
        "    'verbose': True,\n",
        "    'callbacks':callbacks\n",
        "    }\n",
        "\n",
        "val_kwargs = {\n",
        "    'epochs': epochs,\n",
        "    'steps_per_epoch':steps_per_epoch_val,\n",
        "    'verbose': False,\n",
        "    'callbacks':callbacks\n",
        "    }\n",
        "\n",
        "test_kwargs = {'verbose': False,\n",
        "               'steps': steps_per_epoch_test}\n",
        "tmep_kwargs = {'verbose': False}\n",
        "latent_dim = 10"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SF6SJznh6_ur"
      },
      "outputs": [],
      "source": [
        "method = 'vae_graph'\n",
        "input_shape = (x_dim, )\n",
        "result_list_vae = []\n",
        "for seed in range(ITERATIONS):\n",
        "  print(f'Iteration: {seed}')\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "  result_vae = {}\n",
        "  ds_target_dummy = ds_dict_target[0.9]\n",
        "  for w_coeff in w_coeff_list:\n",
        "    print(f'Training model for w_coeff: {w_coeff}')\n",
        "    ds_source = ds_dict_source[w_coeff]\n",
        "    encoder = mlp(num_classes=latent_dim, width=width,\n",
        "                  input_shape=(x_dim + c_dim + w_dim + num_classes,),\n",
        "                  learning_rate=learning_rate,\n",
        "                  metrics=['accuracy'])\n",
        "\n",
        "    model_x2u = mlp(num_classes=latent_dim, width=width, input_shape=(x_dim,),\n",
        "                    learning_rate=learning_rate,\n",
        "                    metrics=['accuracy'])\n",
        "    model_xu2y = mlp(num_classes=num_classes, width=width,\n",
        "                    input_shape=(x_dim + latent_dim,),\n",
        "                    learning_rate=learning_rate,\n",
        "                    metrics=['accuracy'])\n",
        "    vae_opt = tf.keras.optimizers.RMSprop(learning_rate=1e-4)\n",
        "\n",
        "    dims = mlc.ConfigDict()\n",
        "    dims.x = x_dim\n",
        "    dims.y = num_classes\n",
        "    dims.c = c_dim\n",
        "    dims.w = u_dim\n",
        "    dims.u = u_dim\n",
        "\n",
        "    clf = gumbelmax_graph.Method(encoder, width, vae_opt,\n",
        "                        model_x2u, model_xu2y, \n",
        "                        dims, latent_dim, None,\n",
        "                        kl_loss_coef=3,\n",
        "                        num_classes=num_classes, evaluate=evaluate,\n",
        "                        dtype=tf.float32, pos=pos)\n",
        "\n",
        "    clf.fit(ds_source['train'], ds_source['val'], ds_target_dummy['train'],\n",
        "            steps_per_epoch_val, **train_kwargs)\n",
        "    for p_u_0 in p_u_list:\n",
        "      print(f'Adapting for p_u_0: {p_u_0}')\n",
        "      ds_target = ds_dict_target[p_u_0]\n",
        "      clf.freq_ratio = clf._get_freq_ratio(\n",
        "          data_source_val=ds_source['val'], \n",
        "          data_target=ds_target['train'], \n",
        "          num_batches=steps_per_epoch_val\n",
        "        )\n",
        "      result_vae[(p_u_0, w_coeff)] = evaluate_clf(data_dict_source, data_dict_all[(p_u_0, 1)])\n",
        "  result_list_vae.append(result_vae)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SO5T2G3_2fxb"
      },
      "outputs": [],
      "source": [
        "result_vae_concat = pd.concat([\n",
        "    (\n",
        "      pd.concat({key: pd.DataFrame(value) for key, value in elem.items()})\n",
        "      .rename_axis(['p_u_target_0', 'w_coeff', 'eval_set'])\n",
        "      .reset_index()\n",
        "      .melt(id_vars=['p_u_target_0', 'w_coeff', 'eval_set'], value_vars=['cross-entropy', 'accuracy', 'auc'], var_name='metric', value_name='performance')\n",
        "  ).assign(iteration=i) for i, elem in enumerate(result_list_vae)\n",
        "])\n",
        "result_vae_concat"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WiBBRPYI8LbI"
      },
      "outputs": [],
      "source": [
        "result_vae_concat.query('eval_set == \"target\" \u0026 metric == \"auc\"')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vIdqQ3OjTCTA"
      },
      "outputs": [],
      "source": [
        "result_sk_agg = result_sk_concat.groupby(['p_u_target_0', 'method', 'eval_set', 'metric']).agg(performance_mean=('performance', 'mean'), performance_std=('performance', 'std')).reset_index()\n",
        "result_vae_agg = result_vae_concat.groupby(['p_u_target_0', 'w_coeff', 'eval_set', 'metric']).agg(performance_mean=('performance', 'mean'), performance_std=('performance', 'std')).reset_index().assign(method='vae')\n",
        "result_agg = pd.concat([result_sk_agg, result_vae_agg])\n",
        "result_agg"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hbppz04ZTp6y"
      },
      "outputs": [],
      "source": [
        "## Make a plot\n",
        "plt.close()\n",
        "plt.figure(figsize=(6, 4))\n",
        "plot_x = 1-np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])\n",
        "erm_source_mean = result_sk_agg.query('method == \"erm-source\" \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_mean.values\n",
        "erm_source_std = result_sk_agg.query('method == \"erm-source\" \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_std.values\n",
        "plt.errorbar(plot_x, erm_source_mean, yerr=erm_source_std, color='k', linestyle='dashed', lw=2)\n",
        "\n",
        "erm_target_mean = result_sk_agg.query('method == \"erm-target\" \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_mean.values\n",
        "erm_target_std = result_sk_agg.query('method == \"erm-target\" \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_std.values\n",
        "plt.errorbar(plot_x, erm_target_mean, yerr=erm_target_std, color='k', linestyle='dashed', lw=2)\n",
        "\n",
        "\n",
        "lsa_result_mean = result_sk_agg.query('method == \"lsa-oracle-sk\" \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_mean.values\n",
        "lsa_result_std = result_sk_agg.query('method == \"lsa-oracle-sk\" \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_std.values\n",
        "plt.errorbar(plot_x, lsa_result_mean, yerr=lsa_result_std, color='r', lw=2, label='LSA-observed')\n",
        "\n",
        "vae_result_1_mean = result_vae_agg.query('w_coeff == 1 \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_mean.values\n",
        "vae_result_1_std = result_vae_agg.query('w_coeff == 1 \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_std.values\n",
        "plt.errorbar(plot_x, vae_result_1_mean, yerr=vae_result_1_std, color=\"#154360\", lw=2, label='WAE-high-noise')\n",
        "vae_result_2_mean = result_vae_agg.query('w_coeff == 2 \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_mean.values\n",
        "vae_result_2_std = result_vae_agg.query('w_coeff == 2 \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_std.values\n",
        "plt.errorbar(plot_x, vae_result_2_mean, yerr=vae_result_2_std, color=\"#2E86C1\", lw=2, label='WAE-med-noise')\n",
        "vae_result_3_mean = result_vae_agg.query('w_coeff == 3 \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_mean.values\n",
        "vae_result_3_std = result_vae_agg.query('w_coeff ==3 \u0026 eval_set == \"target\" \u0026 metric == \"auc\"').sort_values('p_u_target_0').performance_std.values\n",
        "plt.errorbar(plot_x, vae_result_3_mean, yerr=vae_result_3_std, color=\"#85C1E9\", lw=2, label='WAE-low-noise')\n",
        "\n",
        "plt.ylabel('AUROC', size=20)\n",
        "plt.xlabel('q(U=1)', size=20)\n",
        "plt.legend(fontsize=14)\n",
        "plt.xticks(fontsize=14)\n",
        "plt.yticks(fontsize=14)\n",
        "plt.text(0.95, np.array(erm_source_mean).min(), s='ERM-source', fontsize=16)\n",
        "plt.text(0.95, erm_target_mean[0], s='ERM-target', fontsize=16)\n",
        "sns.despine()\n",
        "\n",
        "figure_folder_id = './tmp_data'\n",
        "figure_filename = 'synthetic_sweep_vae_200_epochs.png'\n",
        "plt.savefig(os.path.join(figure_folder_id, figure_filename), bbox_inches='tight', dpi=90)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "723t4Mndl0pE"
      },
      "outputs": [],
      "source": [
        "# Write the plot data\n",
        "result_agg.to_csv(os.path.join(figure_folder_id, 'synthetic_sweep_vae_results_200_epochs.csv'), index=False)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
